# coding: utf-8

from gensim import corpora
from gensim import models
import matplotlib.pyplot as plt


def get_train_set(path='./data_dir/article', max_count=1000):
    # Read pre-tokenized articles (one article per line, tokens separated by spaces).
    # Keep only articles longer than 200 tokens, up to max_count articles.
    train = []
    count = 0
    with open(path, 'r', encoding='utf-8') as fr:
        for line in fr:
            tokens = line.strip().split(' ')
            if len(tokens) > 200:
                train.append(tokens)
                count += 1
            if count == max_count:
                break
    return train


def get_dictionary(train_set):
    # train_set is a 2-D list, e.g. [ ['zhang', 'wang', 'li'], ['zhang', 'wang'] ]
    # each row holds the tokens of one article, i.e. one document
    dictionary = corpora.Dictionary(train_set)
    # document-frequency filtering: keep tokens that appear in at least 200
    # documents and in no more than 10% of all documents
    dictionary.filter_extremes(no_below=200, no_above=0.1)
    # dictionary.filter_tokens(bad_ids=[dictionary.token2id.get('婆婆')])
    return dictionary


def get_bow_corpus(dictionary, train_set):
    # Convert each tokenized document into its bag-of-words representation,
    # i.e. a list of (token_id, token_count) pairs.
    return [dictionary.doc2bow(text) for text in train_set]


def get_tfidf_model(dictionary, corpus):
    # Fit a TF-IDF model on the bag-of-words corpus.
    return models.TfidfModel(corpus=corpus, dictionary=dictionary)

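
if __name__ == '__main__':
    # Usage sketch: chains the functions above into the full pipeline.
    # It assumes ./data_dir/article exists and contains pre-tokenized,
    # space-separated articles, one per line (not part of the original code).
    train_set = get_train_set()
    dictionary = get_dictionary(train_set)
    corpus = get_bow_corpus(dictionary, train_set)
    tfidf = get_tfidf_model(dictionary, corpus)
    # Applying the model to a single bag-of-words vector returns
    # (token_id, tfidf_weight) pairs for that document.
    first_doc_tfidf = tfidf[corpus[0]]
    print(first_doc_tfidf[:10])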